"Automatic Waste Sorting using AI"
This notebook is an extension of Waste Sorter by Collindching; several code snippets and cells have been adapted, with tweaks, from that work.
The primary focus of this notebook is to improve the model's performance by raising accuracy and reducing the misclassification errors noticed earlier.
This notebook is built on the fast.ai v2 library. More information can be found here: https://course.fast.ai/. Collindching's original notebook targeted fast.ai v1, so some tweaks have been made to support v2.
Recycling contamination occurs when waste is incorrectly disposed of - like recycling a greasy pizza box (which belongs in compost) - or when waste is correctly disposed of but incorrectly prepared - like recycling unrinsed jam jars.
Contamination is a huge problem in the recycling industry that can be mitigated with automated waste sorting. Just for kicks, I thought I'd try my hand at prototyping an image classifier to classify trash and recyclables - this classifier could have applications in an optical sorting system.
In this project, I'll try to reduce the misclassification error that was noticed earlier (in the notebook linked above).
We will follow the same initial steps as before.
The code cell below simply upgrades the fastai library so that we have the latest fixes in one place. While plotting the top-loss images I noticed a few empty plots, and upgrading was the quickest fix I could find, so this extra step is performed to be on the safe side.
import warnings
warnings.filterwarnings('ignore')
!pip install --upgrade git+https://github.com/fastai/fastai.git
Collecting git+https://github.com/fastai/fastai.git
(remaining pip output omitted: all dependencies for fastai==2.5.4 were already satisfied)
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
!pip install -Uqq fastbook
import fastbook
fastbook.setup_book()
from fastbook import *
from fastai.vision.widgets import *
from pathlib import Path
from glob2 import glob
from sklearn.metrics import confusion_matrix
import pandas as pd
import numpy as np
import os
import zipfile as zf
import shutil
import re
import seaborn as sns
First, we need to extract the contents of "dataset-resized.zip".
# Alternatively, fetch the dataset by mounting Google Drive.
from google.colab import drive
drive.mount('/content/gdrive')
Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
# Copy and extract the dataset fetched from Google Drive.
!cp -r gdrive/MyDrive/dataset-resized.zip .
with zf.ZipFile("dataset-resized.zip", 'r') as files:
    files.extractall()
Once unzipped, the dataset-resized folder has six subfolders:
# A basic sanity check to ensure the waste category folders are in place.
os.listdir(os.path.join(os.getcwd(),"dataset-resized"))
['plastic', 'trash', 'cardboard', 'metal', 'glass', 'paper', '.DS_Store']
Now that we've extracted the data, I'm going to split the images up into train, validation, and test folders with a 50-25-25 split. First, I'll define some functions that will help me quickly build it. If you're not interested in how the data set is built, you can just run these cells and ignore the details.
?random
import random

f = os.path.join('dataset-resized', 'plastic')
n = len(os.listdir(f))
k = random.sample(list(range(1, n+1)), int(.5*n))  # sample half the indices at random
print(k)
[254, 390, 231, 242, 334, 195, 404, 108, 49, 250, ...] (full list of sampled indices omitted)
# Cite - https://github.com/collindching/waste-sorter/blob/master/Waste%20sorter.ipynb
## helper functions ##
## splits indices for a folder into train, validation, and test indices with random sampling
## input: folder path
## output: train, valid, and test indices
def split_indices(folder,seed1,seed2):
    n = len(os.listdir(folder))
    print("Folder path:{}".format(folder))
    full_set = list(range(1,n+1))
    ## train indices
    random.seed(seed1)
    train = random.sample(full_set,int(.5*n))
    ## remaining indices after removing train
    remain = list(set(full_set)-set(train))
    ## separate remaining into validation and test
    random.seed(seed2)
    valid = random.sample(remain,int(.5*len(remain)))
    test = list(set(remain)-set(valid))
    print("List of indices\n {}.\n.{}\n.{}".format(train, valid, test))
    return (train,valid,test)
## gets file names for a particular type of trash, given indices
## input: waste category and indices
## output: file names
def get_names(waste_type,indices):
    file_names = [waste_type+str(i)+".jpg" for i in indices]
    return file_names
## moves group of source files to another folder
## input: list of source files and destination folder
## no output
def move_files(source_files,destination_folder):
    for file in source_files:
        shutil.move(file,destination_folder)
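Since the helper functions above drive the whole dataset build, it's worth sanity-checking the 50-25-25 partition logic on its own. The sketch below mimics the index logic of split_indices for a hypothetical folder of n files (the function name split_sizes is mine, not from the original notebook):

```python
import random

def split_sizes(n, seed1=1, seed2=1):
    """Mimic split_indices on a folder of n files; return the three index sets."""
    full_set = set(range(1, n + 1))
    random.seed(seed1)
    train = set(random.sample(sorted(full_set), int(.5 * n)))
    remain = full_set - train
    random.seed(seed2)
    valid = set(random.sample(sorted(remain), int(.5 * len(remain))))
    test = remain - valid
    return train, valid, test

train, valid, test = split_sizes(400)
# the three sets partition 1..400 with no overlap
assert train | valid | test == set(range(1, 401))
assert not (train & valid) and not (valid & test) and not (train & test)
print(len(train), len(valid), len(test))  # 200 100 100
```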
Next, we will follow the same directory convention as ImageNet-style datasets:
/data
/train
/cardboard
/glass
/metal
/paper
/plastic
/trash
/valid
/cardboard
/glass
/metal
/paper
/plastic
/trash
/test
Each image file name is just the material name followed by a number (e.g. cardboard1.jpg).
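Since each file name encodes its category, a small helper using the re module imported earlier can recover the label from a name like cardboard1.jpg (the helper name category_of is hypothetical, not part of the original notebook):

```python
import re

def category_of(file_name):
    """Extract the waste category from names like 'cardboard1.jpg'."""
    match = re.match(r"([a-z]+)(\d+)\.jpg$", file_name)
    return match.group(1) if match else None

print(category_of("cardboard1.jpg"))  # cardboard
print(category_of("glass151.jpg"))   # glass
```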
# Remove the data subset folders if they already exist.
path = Path(os.getcwd()+"/data")
subset = ['train', 'valid', 'test']
for sub_folder in subset:
    if os.path.exists(os.path.join(path, sub_folder)):
        shutil.rmtree(os.path.join(path, sub_folder))
#Cite - https://github.com/collindching/waste-sorter/blob/master/Waste%20sorter.ipynb
## paths will be train/cardboard, train/glass, etc...
subsets = ['train','valid']
waste_types = ['cardboard','glass','metal','paper','plastic','trash']
def verify_source_file(waste_type, dataset_ind):
    ## keep only the file paths that actually exist in the source folder
    dataset_names = get_names(waste_type, dataset_ind)
    source_files = []
    for name in dataset_names:
        path = os.path.join(source_folder, name)  # source_folder is set per category in the loop below
        if not os.path.exists(path):
            continue
        source_files.append(path)
    return source_files
## create destination folders for each data subset and waste type
for subset in subsets:
    for waste_type in waste_types:
        folder = os.path.join('data',subset,waste_type)
        if not os.path.exists(folder):
            os.makedirs(folder)

if not os.path.exists(os.path.join('data','test')):
    os.makedirs(os.path.join('data','test'))
## move files to destination folders for each waste type
for waste_type in waste_types:
    source_folder = os.path.join('dataset-resized',waste_type)
    train_ind, valid_ind, test_ind = split_indices(source_folder,1,1)

    ## move source files to train
    train_source_files = verify_source_file(waste_type, train_ind)
    train_dest = "data/train/"+waste_type
    move_files(train_source_files,train_dest)

    ## move source files to valid
    valid_source_files = verify_source_file(waste_type, valid_ind)
    valid_dest = "data/valid/"+waste_type
    move_files(valid_source_files,valid_dest)

    ## move source files to test
    test_source_files = verify_source_file(waste_type, test_ind)
    move_files(test_source_files,"data/test")
Folder path:dataset-resized/cardboard ... (the printed train/valid/test index lists for all six categories are omitted here for brevity)
I set the seed for both random samples to 1 for reproducibility. Now that the data is organized, we can get to model training.
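The effect of seeding can be checked directly: with the same seed, random.sample returns the same draw, which is exactly what makes the split reproducible across runs. A minimal check:

```python
import random

# two draws with the same seed produce identical samples
random.seed(1)
first = random.sample(range(1, 101), 5)
random.seed(1)
second = random.sample(range(1, 101), 5)

assert first == second
print(first)
```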
One important thing before we go any further: fast.ai uses Path extensively for handling image paths, which makes it convenient not to worry about which directory images have to be loaded from.
In short, path is just the working directory where temporary files and models will be saved. Path is used extensively throughout the fastai library.
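As a minimal illustration of why Path is convenient (using a made-up "data" path rather than the real dataset directory), the "/" operator composes path components and common attributes replace string slicing:

```python
from pathlib import Path

path = Path("data")

# '/' joins path components, so nested folders compose naturally
assert path/"train"/"glass" == Path("data/train/glass")

# useful attributes when iterating over image files
img = path/"train"/"glass"/"glass1.jpg"
print(img.name)    # glass1.jpg
print(img.suffix)  # .jpg
print(img.parent)
```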
## Get the images from particular path...
path = Path(os.getcwd())/"data"
path
Path('/content/data')
# Checking the counts of plastic and glass images in the train and validation sets.
print("Plastic train {}".format(len(os.listdir(str(path)+'/train/plastic'))))
print("Plastic valid {}".format(len(os.listdir(str(path)+'/valid/plastic'))))
print("Glass train {}".format(len(os.listdir(str(path)+'/train/glass'))))
print("Glass valid {}".format(len(os.listdir(str(path)+'/valid/glass'))))
Plastic train 241 Plastic valid 120 Glass train 242 Glass valid 124
The glass images split roughly 2:1 between the train and validation sets, as expected. We will look further at improving the classifier with augmentation and related techniques.
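Checking only plastic and glass by hand doesn't scale to all six categories, so a small helper can report counts for every class at once (the function class_counts is my own addition, not from the original notebook; on Colab, after the split above, it prints the per-class counts for data/train and data/valid):

```python
import os

def class_counts(root):
    """Count files per class folder under root (e.g. 'data/train')."""
    counts = {}
    if not os.path.isdir(root):
        return counts  # missing folder: return an empty report rather than crash
    for cls in sorted(os.listdir(root)):
        folder = os.path.join(root, cls)
        if os.path.isdir(folder):
            counts[cls] = len(os.listdir(folder))
    return counts

print(class_counts("data/train"))
print(class_counts("data/valid"))
```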
#hide
tr_plastic_imgs = os.listdir(str(path)+'/train/plastic')
os.listdir(str(path) + '/train/glass')
['glass151.jpg', 'glass49.jpg', 'glass106.jpg', 'glass246.jpg', ...] (full listing of train/glass file names omitted)
# Viewing the image in the dataset
tr_glass_imgs = os.listdir(str(path)+'/train/glass')
img = Image.open(str(path)+'/train/glass/'+tr_glass_imgs[3])
In the previous version of this notebook, glass was most often misclassified as metal or plastic.
glass_imgs = (path/'train/glass').ls()
im = Image.open(glass_imgs[1])
# converting the glass image to a tensor and inspecting a small patch of pixels...
first_glass = tensor(im)
print(first_glass[1:4, 4:10])
print(first_glass.shape)
tensor([[[230, 210, 186],
[229, 209, 185],
[229, 209, 185],
[229, 209, 185],
[228, 208, 184],
[228, 208, 184]],
[[229, 209, 185],
[229, 209, 185],
[229, 209, 185],
[229, 209, 185],
[228, 208, 184],
[228, 208, 184]],
[[229, 209, 185],
[229, 209, 185],
[228, 208, 184],
[228, 208, 184],
[228, 208, 184],
[228, 208, 184]]], dtype=torch.uint8)
torch.Size([384, 512, 3])
We have converted our glass image to a tensor.
Next steps:
We compute the mean over a stack of sample glass images and check how close the resulting "average" tensor sits to those of plastic and metal, since those were the most misclassified categories for our waste classifier.
# A toy example: averaging a stack of 20 random 4x4 "images" along dim 0.
torch.manual_seed(1)  # note: torch.manual_seed (not random.seed) controls torch.randn
a = torch.randn(20, 4, 4)
print(a)
mean_a = a.mean(0)
print(mean_a.shape)
tensor([[[ 1.9041, -0.6623, -0.0740, -1.8308],
         [ 0.0620, -0.2726,  1.5847, -2.0998],
         [-1.8451, -0.5164, -0.8150,  0.2383],
         [-0.0154, -1.3963, -0.2346,  0.6368]],
        ...
        [[-0.2001,  0.9258,  0.2802, -0.7186],
         [ 0.1723, -1.8113,  0.8975, -1.8717],
         [-0.7931, -0.5257, -0.3883, -0.5874],
         [ 0.6360, -0.6921, -0.7885,  1.1872]]])
(output truncated: 20 random 4x4 matrices)
torch.Size([4, 4])
The code above stacks 20 random 4x4 matrices and takes the mean along the first axis: all 20 matrices are averaged element-wise, yielding a single mean 4x4 matrix.
We will use the same process to stack the 384 x 512 x 3 image tensors in each folder. For example, if the training set has 240 images of 384 x 512 x 3 channels, stacking them produces a tensor of shape (240, 384, 512, 3). We can then average the pixels of all 240 images into one single "mean image", and compare any new 384 x 512 x 3 image against it. This is our baseline model; it will likely not match the state-of-the-art transfer learning models used later in the notebook, but it gives us a point of reference.
glass_tensors = [tensor(Image.open(g_img)) for g_img in glass_imgs]
print(len(glass_tensors))
plastic_imgs = (path/'train/plastic').ls()
plastic_tensors = [tensor(Image.open(g_img)) for g_img in plastic_imgs]
print(len(plastic_tensors))
metal_imgs = (path/'train/metal').ls()
metal_tensors = [tensor(Image.open(g_img)) for g_img in metal_imgs]
print(len(metal_tensors))
# stacking up all the images of glass, plastic and metal...
glass_stack = torch.stack(glass_tensors).float()/255
print(glass_stack.shape, glass_stack.ndim)
plastic_stack = torch.stack(plastic_tensors).float()/255
print(plastic_stack.shape, plastic_stack.ndim)
metal_stack = torch.stack(metal_tensors).float()/255
print(metal_stack.shape, metal_stack.ndim)
mean_glass_tensor = glass_stack.mean(0)
print("mean shape of glass tensor ==>", mean_glass_tensor.shape)
mean_metal_tensor = metal_stack.mean(0)
show_image(mean_metal_tensor)
show_image(mean_glass_tensor)
242
241
205
torch.Size([242, 384, 512, 3]) 4
torch.Size([241, 384, 512, 3]) 4
torch.Size([205, 384, 512, 3]) 4
mean shape of glass tensor ==> torch.Size([384, 512, 3])
<matplotlib.axes._subplots.AxesSubplot at 0x7ff696e7e0d0>
mean_plastic_tensor = plastic_stack.mean(0)
show_image(mean_plastic_tensor)
<matplotlib.axes._subplots.AxesSubplot at 0x7ff699b4d690>
# Example 1: To check which class mean is closest to the pixels of a sample glass image....
diff1 = (glass_stack[1] - mean_glass_tensor).abs().mean()
print(diff1)
diff2 = (glass_stack[1] - mean_plastic_tensor).abs().mean()
print(diff2)
diff3 = (glass_stack[1] - mean_metal_tensor).abs().mean()
print(diff3)
# Clearly the lowest distance is to the glass mean; plastic and metal follow in the ranking..
# Hence this baseline model classifies the sample as glass
tensor(0.0943) tensor(0.0994) tensor(0.1026)
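The same idea extends to a three-way baseline: compute the distance of a candidate image to each class mean and pick the smallest. Here is a self-contained toy sketch with NumPy; the tiny arrays and the `nearest_mean_class` helper are hypothetical stand-ins for the real `mean_glass_tensor` / `mean_plastic_tensor` / `mean_metal_tensor` values above.

```python
import numpy as np

def nearest_mean_class(img, class_means):
    """Nearest-mean baseline: return the class whose mean image has the
    smallest mean absolute pixel difference to `img`."""
    dists = {name: np.abs(img - mean).mean() for name, mean in class_means.items()}
    return min(dists, key=dists.get)

# Toy stand-ins for the per-class mean images (4x4 instead of 384x512).
means = {
    "glass":   np.full((4, 4), 0.2),
    "plastic": np.full((4, 4), 0.5),
    "metal":   np.full((4, 4), 0.8),
}
sample = np.full((4, 4), 0.25)  # pixel values closest to the glass mean
print(nearest_mean_class(sample, means))  # -> glass
```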
# distance = mean absolute pixel difference, averaged over the last two (height, width) axes
def glass_distance(a,b): return (a-b).abs().mean((-1, -2))
# compare against the glass and plastic means, using only the first (red) channel
def is_glass(x): return glass_distance(x,mean_glass_tensor[:, :, 0]) < glass_distance(x,mean_plastic_tensor[:, :, 0])
is_glass(plastic_stack[-1][:, :, 0]), is_glass(plastic_stack[-1][:, :, 0]).float()
(tensor(False), tensor(0.))
valid_path = (path/'valid/glass').ls()
print(len(valid_path))
valid_path = (path/'valid/plastic').ls()
print(len(valid_path))
valid_path = (path/'train/glass').ls()
print(len(valid_path))
valid_path = (path/'train/plastic').ls()
print(len(valid_path))
#######
valid_path = (path/'valid/plastic').ls()
plastic_imgs = [Image.open(img) for img in valid_path]
plastic_valid_tens = torch.stack([tensor(Image.open(img)) for img in valid_path])
plastic_valid_tens = plastic_valid_tens.float()/255
valid_plastic_tens = plastic_valid_tens[:, :, :, 0]
trues = [i for i in is_glass(valid_plastic_tens) if i == True]
falses = [i for i in is_glass(valid_plastic_tens) if i == False]
total = len(trues) + len(falses)  # total number of validation plastic images
print(len(trues), len(falses))
124
120
242
241
68 52
Now let's compute the accuracy of our initial baseline model, built by comparing a new image's pixels against the average glass and plastic images.
valid_path = (path/'valid/glass').ls()
glass_imgs = [Image.open(img) for img in valid_path]
#print(valid_path)
glass_valid_tens = torch.stack([tensor(Image.open(img)) for img in valid_path])
glass_valid_tensors = glass_valid_tens.float()/255
valid_glass_tens = glass_valid_tensors[:, :, :, 0]
accuracy_glass = is_glass(valid_glass_tens).float().mean()
accuracy_plastic = (1 - is_glass(valid_plastic_tens).float()).mean()
# simple accuracy: fraction of glass accepted, fraction of plastic rejected, and their average
print(accuracy_glass, accuracy_plastic, (accuracy_plastic + accuracy_glass)/2)
This simplistic baseline shows why a model can get confused between metal and glass images: the distances between these tensors are very close to each other. In simpler terms, they are the most correlated classes.
Let's now put our real model to work, first applying augmented transforms and then feeding the data into our model. Here are some details about the function:
doc(aug_transforms)
aug_transforms[source]
aug_transforms(mult=1.0,do_flip=True,flip_vert=False,max_rotate=10.0,min_zoom=1.0,max_zoom=1.1,max_lighting=0.2,max_warp=0.2,p_affine=0.75,p_lighting=0.75,xtra_tfms=None,size=None,mode='bilinear',pad_mode='reflection',align_corners=True,batch=False,min_scale=1.0)
Utility func to easily create a list of flip, rotate, zoom, warp, lighting transforms.
| | Type | Default |
|---|---|---|
| mult | float | 1.0 |
| do_flip | bool | True |
| flip_vert | bool | False |
| max_rotate | float | 10.0 |
| min_zoom | float | 1.0 |
| max_zoom | float | 1.1 |
| max_lighting | float | 0.2 |
| max_warp | float | 0.2 |
| p_affine | float | 0.75 |
| p_lighting | float | 0.75 |
| xtra_tfms | NoneType | None |
| size | NoneType | None |
| mode | str | bilinear |
| pad_mode | str | reflection |
| align_corners | bool | True |
| batch | bool | False |
| min_scale | float | 1.0 |
The next important class is ImageDataLoaders.
ImageDataLoaders - a wrapper around DataLoaders with factory methods for computer vision problems.
In simpler terms, ImageDataLoaders provides several helper functions which can easily load the data as a DataLoaders object.
According to the docs of fast.ai ImageDataLoader Doc
This class should not be used directly, one of the factory methods should be preferred instead. All those factory methods accept as arguments:
doc(ImageDataLoaders)
class ImageDataLoaders[source]
ImageDataLoaders(*loaders,path='.',device=None) ::DataLoaders
Basic wrapper around several DataLoaders with factory methods for computer vision problems
| | Type | Default |
|---|---|---|
| loaders | | |
| path | str | . |
| device | NoneType | None |
tfms = aug_transforms(do_flip=True,flip_vert=True)
data= ImageDataLoaders.from_folder(path, train = "train", valid = "valid",
batch_tfms=[*tfms, Normalize.from_stats(*imagenet_stats)],bs = 16)
An important note from the transformations perspective
The idea is to bring each pixel value close to the center, so that the data dimensions are of approximately the same scale. We'd like each feature to have a similar range so that our gradients don't go out of control (and so that we only need one global learning-rate multiplier).
For example, with RGB channels we perform this process per channel with a simple formula: first subtract the per-channel mean, then divide by the per-channel standard deviation. Here is how it looks in numpy:
X -= np.mean(X, axis = 0)
X /= np.std(X, axis = 0)
image source - https://cs231n.github.io/neural-networks-2/
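To see the whole normalization step end to end, here is a small self-contained numpy sketch with a hypothetical batch of images (per-channel statistics, in the spirit of what `Normalize.from_stats(*imagenet_stats)` does for us in the DataLoaders above):

```python
import numpy as np

# A toy batch of 8 RGB images, shape (N, H, W, C), values in [0, 1].
X = np.random.rand(8, 32, 32, 3).astype(np.float32)

# Per-channel normalization: zero-center, then scale to unit variance.
mean = X.mean(axis=(0, 1, 2))   # one mean per channel, shape (3,)
std = X.std(axis=(0, 1, 2))     # one std per channel, shape (3,)
X_norm = (X - mean) / std

print(X_norm.mean(axis=(0, 1, 2)))  # ~0 per channel
print(X_norm.std(axis=(0, 1, 2)))   # ~1 per channel
```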
The batch size bs is how many images you'll train on at a time. Similarly, we can specify the validation batch size, which defaults to the bs we provided. A smaller bs will work for computers with less memory.
You can use aug_transforms() function to augment your data. I'll compare the results from flipping images horizontally and vertically.
print(data.vocab)
['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
show_batch function in fast.ai
A sample of the images can be seen via the data.show_batch method.
It takes a figure-size tuple as an argument.
This function displays a batch of images so we can quickly glance at the data we are working with.
Some of its most commonly used arguments are shown below.
type(data)
fastai.data.core.DataLoaders
# Applying show_batch function to our data loader object 'data' and visualize some of the images
data.show_batch(figsize=(10,8))
If you run the program with CUDA_LAUNCH_BLOCKING=1, CUDA errors will surface with a more exact stack trace.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# just one line to create our state of the art transfer learning model
learn = cnn_learner(data,models.resnet34,metrics=error_rate)
Cite - Notebook link
A residual neural network is a convolutional neural network (CNN) with lots of layers. In particular, resnet34 is a CNN with 34 layers that's been pretrained on the ImageNet database. A pretrained CNN will perform better on new image classification tasks because it has already learned some visual features and can transfer that knowledge over (hence transfer learning).
Since they're capable of describing more complexity, deep neural networks should theoretically perform better than shallow networks on training data. In reality, though, deep neural networks tend to perform empirically worse than shallow ones.
Resnets were created to circumvent this glitch using a hack called shortcut connections. If some nodes in a layer have suboptimal values, you can adjust the weights and biases; if a node is already optimal (its residual is 0), why not leave it alone? Adjustments are only made to nodes on an as-needed basis (when there are non-zero residuals).
When adjustments are needed, shortcut connections apply the identity function to pass information to subsequent layers. This shortens the neural network when possible and allows resnets to have deep architectures and behave more like shallow neural networks. The 34 in resnet34 just refers to the number of layers.
Here is an interesting link on the ResNet architecture:
https://blog.roboflow.com/custom-resnet34-classification-model/
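The shortcut connection described above can be sketched as a minimal residual block in PyTorch. This is an illustrative simplification (the class name is ours), not the exact resnet34 block, which also handles stride and channel changes in the shortcut path:

```python
import torch
import torch.nn as nn

class BasicResidualBlock(nn.Module):
    """Two 3x3 convs plus an identity shortcut: out = relu(F(x) + x)."""
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        # Shortcut connection: if the learned residual is ~0,
        # the block simply passes x through unchanged.
        return self.relu(residual + x)

block = BasicResidualBlock(16)
out = block(torch.randn(2, 16, 8, 8))
print(out.shape)  # torch.Size([2, 16, 8, 8])
```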
doc(learn.model)
(doc output truncated -- the model is an nn.Sequential whose head ends with
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=6, bias=False)
i.e. a 6-class linear output layer.)
A sequential container.
Modules will be added to it in the order they are passed in the
constructor. Alternatively, an OrderedDict of modules can be
passed in. The forward() method of Sequential accepts any
input and forwards it to the first module it contains. It then
"chains" outputs to inputs sequentially for each subsequent module,
finally returning the output of the last module.
The value a Sequential provides over manually calling a sequence
of modules is that it allows treating the whole container as a
single module, such that performing a transformation on the
Sequential applies to each of the modules it stores (which are
each a registered submodule of the Sequential).
What's the difference between a Sequential and a
:class:torch.nn.ModuleList? A ModuleList is exactly what it
sounds like--a list for storing Module s! On the other hand,
the layers in a Sequential are connected in a cascading way.
Example::
# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
# taking one image sample
tr_glass_imgs = os.listdir(str(path)+'/train/glass')[0]
ig=PILImage(PILImage.create(str(path)+'/train/glass/'+tr_glass_imgs).resize((600,400)))
type(array(ig))
numpy.ndarray
# Small example - Consider a tensor
sample_tensor = torch.randn(5,4)
# next step is to use permute method
after_permute_operation = sample_tensor.permute(1, 0)
# permute(1, 0) -> this essentially swaps the two axes
print(after_permute_operation.shape)
assert after_permute_operation.shape[0] == 4
torch.Size([4, 5])
array(ig).shape
(400, 600, 3)
# Using TensorImage class to convert the numpy.ndarray to tensors
# permute method - Used to reorder or reorganize the dimensions of an image
timg = TensorImage(array(ig)).permute(2,0,1).float()/255.
# The below function replicates the single image tensor into a batch of the given size.
def _batch_ex(bs): return TensorImage(timg[None].expand(bs, *timg.shape).clone())
# A quick glance at the fastai classes and functions behind these transformations..
tfms = aug_transforms(do_flip=True)
for i in tfms:
    # with do_flip=True, tfms contains two transform objects: Flip and Brightness
    print("Class ===>>>>", i, i.__getattribute__)
# We can now create an object of all the transformations class and then pass our tensor form of images, shown in below code...
Class ===>>>> Flip -- {'size': None, 'mode': 'bilinear', 'pad_mode': 'reflection', 'mode_mask': 'nearest', 'align_corners': True, 'p': 0.5}:
encodes: (TensorImage,object) -> encodes
(TensorMask,object) -> encodes
(TensorBBox,object) -> encodes
(TensorPoint,object) -> encodes
decodes: <method-wrapper '__getattribute__' of Flip object at 0x7ff627724d10>
Class ===>>>> Brightness -- {'max_lighting': 0.2, 'p': 1.0, 'draw': None, 'batch': False}:
encodes: (TensorImage,object) -> encodes
decodes: <method-wrapper '__getattribute__' of Brightness object at 0x7ff627724850>
# Performing image transformations through function..
y = _batch_ex(2)
for t in tfms:
    # split_idx = 0 means we are passing a training image; split_idx = 1 refers to validation in fastai
    y = t(y, split_idx=0)
_,axs = plt.subplots(1,2, figsize=(10,8))
for i,ax in enumerate(axs.flatten()): show_image(y[i], ctx=ax)
fast.ai uses the one-cycle policy: the learning rate starts at a low value, ramps up to a large value, then settles at a value lower than the initial LR. This has shown impressive results because we end up with neither a uniformly slow nor a dangerously high LR.
To choose the best LR, look at the point where the loss curve is steepest. Please note that the steepest point is not the point of minimum loss; it is the point where the loss is dropping fastest.
image source - https://iconof.com/1cycle-learning-rate-policy/
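A rough, hypothetical sketch of the one-cycle shape as a function of training step (piecewise-linear warm-up and anneal; fastai's real scheduler uses cosine interpolation and also cycles momentum):

```python
def one_cycle_lr(step, total_steps, lr_max, pct_start=0.25, div=25.0, final_div=1e4):
    """Piecewise-linear one-cycle schedule:
    ramp lr_max/div -> lr_max, then anneal lr_max -> lr_max/final_div."""
    warm = int(total_steps * pct_start)
    if step < warm:
        frac = step / max(warm, 1)
        return lr_max / div + frac * (lr_max - lr_max / div)
    frac = (step - warm) / max(total_steps - warm, 1)
    return lr_max - frac * (lr_max - lr_max / final_div)

total = 100
lrs = [one_cycle_lr(s, total, lr_max=5e-3) for s in range(total)]
print(max(lrs))  # -> 0.005 (= lr_max, reached at the end of warm-up)
```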
# start_lr = learning rate at which the search starts
# end_lr = learning rate at which the search stops
learn.lr_find(start_lr=1e-6,end_lr=1e1)
/usr/lib/python3.7/json/encoder.py:257: UserWarning: date_default is deprecated since jupyter_client 7.0.0. Use jupyter_client.jsonutil.json_default. return _iterencode(o, 0)
SuggestedLRs(valley=0.0012022644514217973)
lr_min, lr_steep = learn.lr_find(suggest_funcs=(minimum, steep))
# a simple way to sample a random value between two candidate bounds, which can help when guessing an optimal LR
import numpy as np
np.random.uniform(1e-4, 1e-3)
0.0007015538560113359
# fit_one_cycle - trains the model for a given number of epochs using the one-cycle policy
# One epoch = one complete cycle of forward propagation/back propagation over the training set
learn.fit_one_cycle(20, lr_max=5e-03)
| epoch | train_loss | valid_loss | error_rate | time |
|---|---|---|---|---|
| 0 | 1.626455 | 0.673134 | 0.244833 | 01:40 |
| 1 | 0.984484 | 0.516873 | 0.189189 | 01:40 |
| 2 | 0.834137 | 0.616733 | 0.178060 | 01:40 |
| 3 | 0.862364 | 0.832932 | 0.243243 | 01:40 |
| 4 | 0.840459 | 0.745179 | 0.225755 | 01:40 |
| 5 | 0.778823 | 0.542830 | 0.181240 | 01:40 |
| 6 | 0.682324 | 0.816168 | 0.244833 | 01:40 |
| 7 | 0.638330 | 0.502149 | 0.155803 | 01:40 |
| 8 | 0.557140 | 0.393592 | 0.117647 | 01:40 |
| 9 | 0.572882 | 0.529694 | 0.165342 | 01:40 |
| 10 | 0.462756 | 0.340510 | 0.109698 | 01:40 |
| 11 | 0.385636 | 0.312294 | 0.112878 | 01:40 |
| 12 | 0.369283 | 0.339354 | 0.114467 | 01:40 |
| 13 | 0.298273 | 0.284633 | 0.084261 | 01:40 |
| 14 | 0.274431 | 0.237938 | 0.076312 | 01:40 |
| 15 | 0.236636 | 0.229560 | 0.071542 | 01:40 |
| 16 | 0.165803 | 0.201756 | 0.068362 | 01:40 |
| 17 | 0.143642 | 0.204025 | 0.069952 | 01:40 |
| 18 | 0.144287 | 0.192077 | 0.066773 | 01:40 |
| 19 | 0.185571 | 0.196076 | 0.071542 | 01:39 |
The model ran for 20 epochs, reaching a minimum error rate of 0.066, better than the previous model's ~0.08. Hence we were able to reduce the error and increase the accuracy, as we will see later.
interp = ClassificationInterpretation.from_learner(learn)
losses,idxs = interp.top_losses()
plt.figure(figsize=(10, 8))
interp.plot_top_losses(4, nrows=2)
plt.show()
<Figure size 720x576 with 0 Axes>
interp.plot_top_losses(10, figsize=(15,11))
The images here are from the dataset after removing some over-exposed images, which was not done in collindching's previous notebook. This has had an impact: the model is now less confused, and the loss decreased.
# Here's a documentation of plot_top_losses function would look like...
doc(interp.plot_top_losses)
Interpretation.plot_top_losses[source]
Interpretation.plot_top_losses(k,largest=True, **kwargs)
Show k largest(/smallest) preds and losses. k may be int, list, or range of desired results.
| | Type | Default |
|---|---|---|
| k | | |
| largest | bool | True |
| kwargs | | |
In the previous version of this notebook by collindching, the model often confused plastic for glass and metal for glass. The list of most-confused pairs is below. Later, we'll see whether we were able to reduce the overall misclassification error for these categories.
# checking where our model got most confused in classifying...
interp.most_confused(min_val=2)
[('glass', 'metal', 7),
('cardboard', 'paper', 5),
('glass', 'plastic', 5),
('plastic', 'metal', 4),
('plastic', 'paper', 4),
('trash', 'paper', 4),
('metal', 'paper', 2),
('metal', 'plastic', 2),
('plastic', 'trash', 2)]
To see how this model really performs, we need to make predictions on test data. First, I'll make predictions on the test data using the learner.get_preds() method.
Note: learner.predict() only predicts on a single image, while learner.get_preds() predicts on a set of images. I highly recommend reading the documentation to learn more about predict() and get_preds().
doc(learn.predict)
Learner.predict[source]
Learner.predict(item,rm_type_tfms=None,with_input=False)
Prediction on item, fully decoded, loss function decoded and probabilities
| | Type | Default |
|---|---|---|
| item | | |
| rm_type_tfms | NoneType | None |
| with_input | bool | False |
In the cells above, we defined data loaders for our train and validation sets.
However, for the model to predict on new image data, we have to load that data into a dataloader object as well.
We simply pass the test folder to the get_image_files function, which traverses it and returns all the images inside.
get_preds is the function which gives us, for each image, the probability of belonging to each class.
#get predictions..
test_dl = data.test_dl(get_image_files(os.path.join(path, 'test')))
preds = learn.get_preds(dl=test_dl)
print(preds[0].shape)
preds[0]
torch.Size([631, 6])
TensorBase([[1.7182e-04, 7.1475e-05, 1.5689e-03, 2.3538e-05, 9.9789e-01, 2.7502e-04],
[8.1237e-08, 5.9949e-08, 1.6182e-08, 9.9999e-01, 1.9660e-07, 1.1921e-05],
[1.3310e-06, 6.2397e-03, 9.9346e-01, 4.1719e-07, 2.6252e-04, 3.6827e-05],
...,
[5.1307e-06, 5.3394e-05, 9.9495e-01, 1.8174e-04, 9.8943e-05, 4.7127e-03],
[1.7216e-07, 1.0502e-05, 7.9543e-04, 7.6339e-06, 9.9208e-05, 9.9909e-01],
[8.3721e-06, 5.6436e-08, 6.0167e-08, 9.9991e-01, 1.5351e-05, 6.4679e-05]])
Simple approach - for each image, take the class with the maximum probability among the 6 classes.
That is, choose the maximum value along axis 1, i.e. the columns.
Each row holds the probabilities for one sample image.
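As a tiny illustration of that argmax step (toy numbers, not the real predictions):

```python
import numpy as np

# Toy predictions: 2 images x 3 classes.
probs = np.array([[0.1, 0.7, 0.2],
                  [0.8, 0.1, 0.1]])
classes = ["cardboard", "glass", "metal"]

max_idxs = np.argmax(probs, axis=1)      # best class index per row
labels = [classes[i] for i in max_idxs]  # map indices back to class names
print(labels)  # ['glass', 'cardboard']
```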
## saves the index (0 to 5) of most likely (max) predicted class for each image
max_idxs = np.asarray(np.argmax(preds[0],axis=1))
# Append the class labels with index found from maximum probability class..
max_idxs = np.asarray(np.argmax(preds[0],axis=1))
classes = data.vocab
print(classes)
yhat = []
for max_idx in max_idxs:
    yhat.append(classes[max_idx])
['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']
# A quick check on what get_image_files does...
l = get_image_files(os.path.join(path, 'test'))
from PIL import Image
Image.open(l[0])
# Validate against the test set...
learn.validate(dl=test_dl)
(#2) [None,None]
y = []
## convert POSIX paths to string first
for label_path in test_dl.items:
    y.append(str(label_path))
# then extract the waste type from the file path, e.g. 'glass123.jpg' -> 'glass'
pattern = re.compile("([a-z]+)[0-9]+")
for i in range(len(y)):
    y[i] = pattern.search(y[i]).group(1)
## predicted values
print(yhat[0:30])
## actual values
print(y[0:30])
df_dict = {'actual': y, 'predicted': yhat}
pd.DataFrame(df_dict)
['plastic', 'paper', 'metal', 'paper', 'cardboard', 'cardboard', 'paper', 'trash', 'cardboard', 'glass', 'plastic', 'paper', 'paper', 'cardboard', 'glass', 'plastic', 'glass', 'plastic', 'plastic', 'glass', 'paper', 'glass', 'paper', 'trash', 'cardboard', 'paper', 'plastic', 'metal', 'paper', 'paper']
['plastic', 'paper', 'metal', 'paper', 'cardboard', 'cardboard', 'paper', 'trash', 'cardboard', 'glass', 'plastic', 'plastic', 'paper', 'cardboard', 'glass', 'plastic', 'glass', 'plastic', 'plastic', 'glass', 'paper', 'glass', 'paper', 'trash', 'cardboard', 'paper', 'plastic', 'metal', 'paper', 'paper']
| actual | predicted | |
|---|---|---|
| 0 | plastic | plastic |
| 1 | paper | paper |
| 2 | metal | metal |
| 3 | paper | paper |
| 4 | cardboard | cardboard |
| ... | ... | ... |
| 626 | plastic | plastic |
| 627 | glass | glass |
| 628 | metal | metal |
| 629 | trash | trash |
| 630 | paper | paper |
631 rows × 2 columns
It looks like these early predictions mostly match up - only one of the first 30 differs.
How did we end up doing? Again we can use a confusion matrix to find out.
cm = confusion_matrix(y,yhat)
print(cm)
[[ 96   0   2   2   1   0]
 [  0 111   5   0   6   0]
 [  0   5  98   0   0   0]
 [  1   0   0 145   1   2]
 [  0   2   0   1 117   1]
 [  0   0   0   5   3  27]]
waste_types = list(data.vocab)  # class labels for the matrix axes
df_cm = pd.DataFrame(cm,waste_types,waste_types)
plt.figure(figsize=(10,8))
sns.heatmap(df_cm,annot=True,fmt="d",cmap="YlGnBu")
<matplotlib.axes._subplots.AxesSubplot at 0x7f084460cdd0>
correct = 0
for r in range(len(cm)):
    for c in range(len(cm)):
        if (r==c):
            correct += cm[r,c]
accuracy = correct/sum(sum(cm))
accuracy
0.9413629160063391
We ended up achieving 94.1% accuracy, which is slightly better than the previous notebook. We were also able to work on the next steps mentioned in collindching's notebook and reduce the misclassification error as well.
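The same confusion matrix also yields per-class precision and recall, which makes the remaining glass/metal confusion easy to quantify. A short sketch using the matrix printed above (np.trace reproduces the diagonal loop's count):

```python
import numpy as np

# Confusion matrix from the test set above (rows = actual, cols = predicted).
cm = np.array([[ 96,   0,   2,   2,   1,   0],
               [  0, 111,   5,   0,   6,   0],
               [  0,   5,  98,   0,   0,   0],
               [  1,   0,   0, 145,   1,   2],
               [  0,   2,   0,   1, 117,   1],
               [  0,   0,   0,   5,   3,  27]])
classes = ['cardboard', 'glass', 'metal', 'paper', 'plastic', 'trash']

accuracy = np.trace(cm) / cm.sum()        # same value as the diagonal loop above
recall = np.diag(cm) / cm.sum(axis=1)     # of the actual X images, how many were found
precision = np.diag(cm) / cm.sum(axis=0)  # of the predicted X images, how many were right
print(round(accuracy, 4))  # 0.9414
for name, p, r in zip(classes, precision, recall):
    print(f"{name:9s} precision={p:.3f} recall={r:.3f}")
```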
Collindching's version of the confusion matrix vs. my version
(comparison images shown in the original notebook)
Google Collaboratory Link for the code and experiments
https://github.com/amay1212/Waste-Sorting/blob/master/Waste_Sorter%20Extended.ipynb
# hide
## delete everything when you're done to save space
# shutil.rmtree("data")
# shutil.rmtree('dataset-resized')
Next steps:
To improve accuracy even further on the misclassified categories.
To try and test this model with back rep using the below.
To see if we can distinguish more clearly between dry and wet waste in particular.